{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Neural text generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from seq2seq import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Locate the 'giga-fren' folder under the configured data path and load the saved data object from it.\n",
    "path = Config().data_path()/'giga-fren'\n",
    "data = load_data(path)\n",
    "# Load the two saved embedding objects: 'fr_emb' is used as the encoder embedding and 'en_emb' as the\n",
    "# decoder embedding. NOTE: torch.load unpickles the file, so only load checkpoints you trust.\n",
    "model_path = Config().model_path()\n",
    "emb_enc = torch.load(model_path/'fr_emb.pth')\n",
    "emb_dec = torch.load(model_path/'en_emb.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Seq2SeqRNN_attn(nn.Module):\n",
    "    def __init__(self, emb_enc, emb_dec, nh, out_sl, nl=2, bos_idx=0, pad_idx=1):\n",
    "        super().__init__()\n",
    "        self.nl,self.nh,self.out_sl,self.pr_force = nl,nh,out_sl,1\n",
    "        self.bos_idx,self.pad_idx = bos_idx,pad_idx\n",
    "        self.emb_enc,self.emb_dec = emb_enc,emb_dec\n",
    "        self.emb_sz_enc,self.emb_sz_dec = emb_enc.embedding_dim,emb_dec.embedding_dim\n",
    "        self.voc_sz_dec = emb_dec.num_embeddings\n",
    "                 \n",
    "        self.emb_enc_drop = nn.Dropout(0.15)\n",
    "        self.gru_enc = nn.GRU(self.emb_sz_enc, nh, num_layers=nl, dropout=0.25, \n",
    "                              batch_first=True, bidirectional=True)\n",
    "        self.out_enc = nn.Linear(2*nh, self.emb_sz_dec, bias=False)\n",
    "        self.gru_dec = nn.GRU(self.emb_sz_dec + 2*nh, self.emb_sz_dec, num_layers=nl,\n",
    "                              dropout=0.1, batch_first=True)\n",
    "        self.out_drop = nn.Dropout(0.35)\n",
    "        self.out = nn.Linear(self.emb_sz_dec, self.voc_sz_dec)\n",
    "        self.out.weight.data = self.emb_dec.weight.data\n",
    "        \n",
    "        self.enc_att = nn.Linear(2*nh, self.emb_sz_dec, bias=False)\n",
    "        self.hid_att = nn.Linear(self.emb_sz_dec, self.emb_sz_dec)\n",
    "        self.V =  self.init_param(self.emb_sz_dec)\n",
    "        \n",
    "    def encoder(self, bs, inp):\n",
    "        h = self.initHidden(bs)\n",
    "        emb = self.emb_enc_drop(self.emb_enc(inp))\n",
    "        enc_out, hid = self.gru_enc(emb, 2*h)\n",
    "        pre_hid = hid.view(2, self.nl, bs, self.nh).permute(1,2,0,3).contiguous()\n",
    "        pre_hid = pre_hid.view(self.nl, bs, 2*self.nh)\n",
    "        hid = self.out_enc(pre_hid)\n",
    "        return hid,enc_out\n",
    "    \n",
    "    def decoder(self, dec_inp, hid, enc_att, enc_out):\n",
    "        hid_att = self.hid_att(hid[-1])\n",
    "        u = torch.tanh(enc_att + hid_att[:,None])\n",
    "        attn_wgts = F.softmax(u @ self.V, 1)\n",
    "        ctx = (attn_wgts[...,None] * enc_out).sum(1)\n",
    "        emb = self.emb_dec(dec_inp)\n",
    "        outp, hid = self.gru_dec(torch.cat([emb, ctx], 1)[:,None], hid)\n",
    "        outp = self.out(self.out_drop(outp[:,0]))\n",
    "        return hid, outp\n",
    "        \n",
    "    def forward(self, inp, targ=None):\n",
    "        bs, sl = inp.size()\n",
    "        hid,enc_out = self.encoder(bs, inp)\n",
    "        dec_inp = inp.new_zeros(bs).long() + self.bos_idx\n",
    "        enc_att = self.enc_att(enc_out)\n",
    "        \n",
    "        res = []\n",
    "        for i in range(self.out_sl):\n",
    "            hid, outp = self.decoder(dec_inp, hid, enc_att, enc_out)\n",
    "            res.append(outp)\n",
    "            dec_inp = outp.max(1)[1]\n",
    "            if (dec_inp==self.pad_idx).all(): break\n",
    "            if (targ is not None) and (random.random()<self.pr_force):\n",
    "                if i>=targ.shape[1]: continue\n",
    "                dec_inp = targ[:,i]\n",
    "        return torch.stack(res, dim=1)\n",
    "\n",
    "    def initHidden(self, bs): return one_param(self).new_zeros(2*self.nl, bs, self.nh)\n",
    "    def init_param(self, *sz): return nn.Parameter(torch.randn(sz)/math.sqrt(sz[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Attention seq2seq model: 256 hidden units per encoder direction, at most 30 decoding steps.\n",
    "model = Seq2SeqRNN_attn(emb_enc, emb_dec, nh=256, out_sl=30)\n",
    "learn = Learner(\n",
    "    data, model,\n",
    "    loss_func=seq2seq_loss, metrics=seq2seq_acc,\n",
    "    callback_fns=partial(TeacherForcing, end_epoch=30),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>seq2seq_acc</th>\n",
       "      <th>bleu</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>2.411852</td>\n",
       "      <td>3.832309</td>\n",
       "      <td>0.489762</td>\n",
       "      <td>0.266689</td>\n",
       "      <td>02:08</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>1.962503</td>\n",
       "      <td>4.193816</td>\n",
       "      <td>0.490516</td>\n",
       "      <td>0.398154</td>\n",
       "      <td>01:52</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>1.705578</td>\n",
       "      <td>4.441123</td>\n",
       "      <td>0.456639</td>\n",
       "      <td>0.390446</td>\n",
       "      <td>01:51</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>1.562641</td>\n",
       "      <td>4.048814</td>\n",
       "      <td>0.472333</td>\n",
       "      <td>0.403023</td>\n",
       "      <td>01:49</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>1.373088</td>\n",
       "      <td>4.048331</td>\n",
       "      <td>0.477947</td>\n",
       "      <td>0.413261</td>\n",
       "      <td>01:50</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# One-cycle training run; the stored output table covers epochs 0-4.\n",
    "# NOTE(review): that table has a `bleu` column although `learn` is built with metrics=seq2seq_acc only,\n",
    "# so the saved output likely comes from an earlier kernel state - re-run to refresh it.\n",
    "learn.fit_one_cycle(5, 3e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uncomment after training to write checkpoint '5', which the next cell loads.\n",
    "# learn.save('5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/jhoward/anaconda3/lib/python3.7/site-packages/torch/serialization.py:256: UserWarning: Couldn't retrieve source code for container of type Seq2SeqRNN_attn. It won't be checked for correctness upon loading.\n",
      "  \"type \" + obj.__name__ + \". It won't be checked \"\n"
     ]
    }
   ],
   "source": [
    "# Load the checkpoint saved under the name '5' into `learn`; the trailing ';' hides the cell's return value.\n",
    "learn.load('5');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def preds_acts(learn, ds_type=DatasetType.Valid):\n",
    "    \"Like `get_predictions`, but additionally hands back the raw (non-reconstructed) tensors and activations\"\n",
    "    learn.model.eval()\n",
    "    train_ds = learn.data.train_ds  # only used for its `x`/`y` `reconstruct` methods\n",
    "    rxs,rys,rzs = [],[],[]  # reconstructed inputs / targets / predictions\n",
    "    xs,ys,zs = [],[],[]     # matching raw input ids / target ids / activations\n",
    "    with torch.no_grad():\n",
    "        for xb,yb in progress_bar(learn.dl(ds_type)):\n",
    "            acts = learn.model(xb)\n",
    "            for inp,targ,act in zip(xb,yb,acts):\n",
    "                rxs.append(train_ds.x.reconstruct(inp))\n",
    "                rys.append(train_ds.y.reconstruct(targ))\n",
    "                # Most likely token at every step of the sequence.\n",
    "                rzs.append(train_ds.y.reconstruct(act.argmax(1)))\n",
    "                xs.append(inp)\n",
    "                ys.append(targ)\n",
    "                zs.append(act)\n",
    "    return rxs,rys,rzs,xs,ys,zs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "        <style>\n",
       "            /* Turns off some styling */\n",
       "            progress {\n",
       "                /* gets rid of default border in Firefox and Opera. */\n",
       "                border: none;\n",
       "                /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "                background-size: auto;\n",
       "            }\n",
       "            .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "                background: #F44336;\n",
       "            }\n",
       "        </style>\n",
       "      <progress value='151' class='' max='151', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      100.00% [151/151 00:27<00:00]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Run the model over the default dataset type (validation) and keep texts, raw tensors and activations.\n",
    "rxs,rys,rzs,xs,ys,zs = preds_acts(learn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Text xxbos quelles sont les lacunes qui existent encore dans notre connaissance du travail autonome et sur lesquelles les recherches devraient se concentrer à l’avenir ?,\n",
       " Text xxbos what gaps remain in our knowledge of xxunk on which future research should focus ?,\n",
       " Text xxbos what gaps are needed in our work and what is the research of the work and what research will be in place to future ?)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pick one example: reconstructed texts (r*) alongside the matching raw tensors / activations.\n",
    "idx = 701\n",
    "rx, x = rxs[idx], xs[idx]\n",
    "ry, y = rys[idx], ys[idx]\n",
    "rz, z = rzs[idx], zs[idx]\n",
    "rx, ry, rz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def select_topk(outp, k=5):\n",
    "    probs = F.softmax(outp,dim=-1)\n",
    "    vals,idxs = probs.topk(k, dim=-1)\n",
    "    return idxs[torch.randint(k, (1,))]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Nucleus (top-p) sampling, from [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from random import choice\n",
    "\n",
    "def select_nucleus(outp, p=0.5):\n",
    "    probs = F.softmax(outp,dim=-1)\n",
    "    idxs = torch.argsort(probs, descending=True)\n",
    "    res,cumsum = [],0.\n",
    "    for idx in idxs:\n",
    "        res.append(idx)\n",
    "        cumsum += probs[idx]\n",
    "        if cumsum>p: return idxs.new_tensor([choice(res)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def decode(self, inp):\n",
    "    inp = inp[None]\n",
    "    bs, sl = inp.size()\n",
    "    hid,enc_out = self.encoder(bs, inp)\n",
    "    dec_inp = inp.new_zeros(bs).long() + self.bos_idx\n",
    "    enc_att = self.enc_att(enc_out)\n",
    "\n",
    "    res = []\n",
    "    for i in range(self.out_sl):\n",
    "        hid, outp = self.decoder(dec_inp, hid, enc_att, enc_out)\n",
    "        dec_inp = select_nucleus(outp[0], p=0.3)\n",
    "#         dec_inp = select_topk(outp[0], k=2)\n",
    "        res.append(dec_inp)\n",
    "        if (dec_inp==self.pad_idx).all(): break\n",
    "    return torch.cat(res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict_with_decode(learn, x, y):\n",
    "    \"Sample a prediction for `x` with `decode`; returns the reconstructed (input, target, prediction)\"\n",
    "    learn.model.eval()\n",
    "    train_ds = learn.data.train_ds  # only used for its `x`/`y` `reconstruct` methods\n",
    "    with torch.no_grad():\n",
    "        sampled = decode(learn.model, x)\n",
    "        texts = (\n",
    "            train_ds.x.reconstruct(x),\n",
    "            train_ds.y.reconstruct(y),\n",
    "            train_ds.y.reconstruct(sampled),\n",
    "        )\n",
    "    return texts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text xxbos what gaps are needed in our understanding of work and security and how research will need to be put in place ?"
      ]
     },
     "execution_count": 91,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Re-decode the example picked earlier (`x`, `y`) with sampling; the text can differ from run to run.\n",
    "rx,ry,rz = predict_with_decode(learn, x, y)\n",
    "rz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
